{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecd852b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ee96100",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import os\n",
    "import random\n",
    "from contextlib import contextmanager\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchaudio\n",
    "import webdataset as wds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27b2c09c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def shard_glob(input):\n",
    "    \"\"\"Expand a shard specification into a list of shard path strings.\n",
    "\n",
    "    `input` can be:\n",
    "    - a list or tuple of paths (every item is converted to a string),\n",
    "    - a string or `Path` containing a `{` brace pattern (expanded by webdataset),\n",
    "    - a directory (all `*.tar.gz` files directly inside of it),\n",
    "    - a path whose last component is a glob pattern (matched in its parent directory).\n",
    "\n",
    "    Raises `ValueError` for any other kind of input.\n",
    "    \"\"\"\n",
    "    if isinstance(input, (list, tuple)):\n",
    "        return [str(x) for x in input]\n",
    "    if not isinstance(input, (Path, str)):\n",
    "        raise ValueError(\"input should be either a list or a path with an optional glob specifier\")\n",
    "    # convert first: the `in` operator is not supported on Path objects\n",
    "    spec = str(input)\n",
    "    if '{' in spec:\n",
    "        return wds.shardlists.expand_urls(spec)\n",
    "    path = Path(spec)\n",
    "    if path.is_dir():\n",
    "        glob = '*.tar.gz'\n",
    "    else:\n",
    "        # the final path component is treated as the glob pattern\n",
    "        glob = path.name\n",
    "        path = path.parent\n",
    "    return [str(x) for x in path.glob(glob)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e427663b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class join_datasets(torch.utils.data.IterableDataset):\n",
    "    \"\"\"Interleave several iterable datasets by randomly sampling which one to read next.\n",
    "\n",
    "    A dataset can have a `weight` attribute (defaults to 1) which sets its relative\n",
    "    sampling probability. Iteration ends the first time the sampled dataset turns out\n",
    "    to be exhausted. `len()` needs a `total_samples` attribute on every dataset.\n",
    "    \"\"\"\n",
    "    def __init__(self, datasets):\n",
    "        self.datasets = datasets\n",
    "\n",
    "    def __iter__(self):\n",
    "        weights = torch.tensor(\n",
    "            [getattr(dataset, 'weight', 1) for dataset in self.datasets],\n",
    "            dtype=torch.float,\n",
    "        )\n",
    "        iterators = [iter(dataset) for dataset in self.datasets]\n",
    "        while True:\n",
    "            try:\n",
    "                # draw the index of the dataset that provides the next sample\n",
    "                picked = iterators[torch.multinomial(weights, 1)]\n",
    "                yield next(picked)\n",
    "            except StopIteration:\n",
    "                return\n",
    "\n",
    "    def __len__(self):\n",
    "        return sum(dataset.total_samples for dataset in self.datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b82aaf10",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "a\n",
      "1\n",
      "2\n",
      "3\n",
      "A\n",
      "4\n",
      "5\n",
      "b\n",
      "B\n",
      "c\n",
      "C\n",
      "D\n",
      "E\n",
      "6\n",
      "d\n",
      "e\n",
      "7\n",
      "F\n",
      "f\n",
      "g\n",
      "8\n",
      "G\n",
      "9\n"
     ]
    }
   ],
   "source": [
    "# iteration ends the first time an already exhausted dataset gets sampled\n",
    "demo_datasets = ['ABCDEFG', 'abcdefg', range(20)]\n",
    "for item in join_datasets(demo_datasets):\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f72d7e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def resampler(newsr = 24000, key = 'samples_24k'):\n",
    "    \"\"\"Create a pipeline stage that adds audio resampled to `newsr` to every sample.\n",
    "\n",
    "    The returned function takes an iterable of sample dicts (with `samples` and\n",
    "    `sample_rate` entries) and yields the same dicts with the resampled audio stored\n",
    "    under `key`. Audio which already is at `newsr` is stored under `key` as is.\n",
    "    \"\"\"\n",
    "    def _resample(samples):\n",
    "        # the transform is kept for as long as the source sample rate does not change;\n",
    "        # the state lives inside the generator so every stream gets its own cache\n",
    "        last_sr = None\n",
    "        tform = None\n",
    "        for s in samples:\n",
    "            sr = s['sample_rate']\n",
    "            if sr != newsr:\n",
    "                if sr != last_sr:\n",
    "                    tform = torchaudio.transforms.Resample(sr, newsr)\n",
    "                    last_sr = sr\n",
    "                s[key] = tform(s['samples'])\n",
    "            else:\n",
    "                s[key] = s['samples']\n",
    "            yield s\n",
    "\n",
    "    return _resample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2cc4bcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def derived_name(input, kind, base=\"audio\", suffix=\".gz\", dir=None):\n",
    "    \"\"\"Build the path of the `kind` shard derived from the `base` shard `input`.\n",
    "\n",
    "    The `-{base}-` marker in the file name is replaced with `-{kind}-` and `suffix`\n",
    "    is appended. The result is placed in `dir` if given, otherwise next to `input`.\n",
    "    \"\"\"\n",
    "    src = Path(input)\n",
    "    out_dir = Path(dir) if dir else src.parent\n",
    "    new_name = src.name.replace(f\"-{base}-\", f\"-{kind}-\") + suffix\n",
    "    return str(out_dir / new_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d93ebc",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def derived_dataset(kind, base='audio', suffix=\".gz\", decoders=[], dir=None):\n",
    "    \"\"\"Create a function that opens the `kind` dataset derived from a given shard URL.\n",
    "\n",
    "    The returned function converts a shard URL with `derived_name` and returns\n",
    "    a `WebDataset` reading just that one derived shard, with `decoders` applied.\n",
    "    \"\"\"\n",
    "    def deriver(url):\n",
    "        derived_url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))\n",
    "        shard_list = wds.SimpleShardList([derived_url])\n",
    "        return wds.WebDataset(shard_list).decode(*decoders)\n",
    "    return deriver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6126400c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def merge_in(dataset_fun):\n",
    "    \"\"\"Merge a dataset into the current one returning samples with the union of keys. Pass in a function\n",
    "    that takes a URL of a sample and returns a dataset for it (called every time the URL changes).\n",
    "\n",
    "    It requires (and validates) that both datasets have the same ordering of keys so you have\n",
    "    to use it before any sample shuffling. Shard shuffling is ok.\n",
    "\n",
    "    If both samples have the same field the value from the main dataset wins.\n",
    "    Raises `AssertionError` when the `__key__` values do not match and `RuntimeError`\n",
    "    when the merged-in dataset yields no samples at all.\n",
    "    \"\"\"\n",
    "    def merge_loop(main_samples):\n",
    "        merged_samples = None\n",
    "        cur_url = None\n",
    "        for s in main_samples:\n",
    "            url = s['__url__']\n",
    "            if url != cur_url:\n",
    "                # this will open a new file when we get the first sample with a new __url__\n",
    "                merged_samples = iter(dataset_fun(url))\n",
    "                cur_url = url\n",
    "            try:\n",
    "                merge_s = next(merged_samples)\n",
    "            except StopIteration:\n",
    "                # if the original shard got repeated we won't observe a __url__ change\n",
    "                # in this case restart the dataset from the beginning\n",
    "                merged_samples = iter(dataset_fun(url))\n",
    "                # an empty dataset would otherwise leak StopIteration out of this generator\n",
    "                # (Python reports that as an opaque 'generator raised StopIteration')\n",
    "                merge_s = next(merged_samples, None)\n",
    "                if merge_s is None:\n",
    "                    raise RuntimeError(f\"merged dataset for {url} yielded no samples\") from None\n",
    "            # raised explicitly (instead of `assert`) so the check survives `python -O`\n",
    "            if merge_s['__key__'] != s['__key__']:\n",
    "                raise AssertionError(f\"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}\")\n",
    "            news = {}\n",
    "            news.update(merge_s)\n",
    "            news.update(s)\n",
    "            yield news\n",
    "    return merge_loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c60414cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):\n",
    "    \"\"\"Split every sample of `stream` into one sample per segment listed in `s[ikey]`.\n",
    "\n",
    "    Each input sample needs an `audio` entry holding an `(audio, sample_rate)` pair and\n",
    "    a sequence of `(start, end)` times in seconds under `ikey`. Only index 0 of the first\n",
    "    audio dimension is used (presumably the first channel - TODO confirm).\n",
    "\n",
    "    ikey: key of the segment list\n",
    "    metakeys: keys of per-segment sequences; the i-th element is copied into the i-th chunk\n",
    "    pad_to_seconds: pad every chunk to this length (`None` disables the padding)\n",
    "    random_shift: put a random part of the padding in front of the audio instead of\n",
    "        appending all of it at the end\n",
    "    \"\"\"\n",
    "    for s in stream:\n",
    "        audio, sr = s['audio']\n",
    "        imax = len(s[ikey]) - 1\n",
    "        for i,(ts,te) in enumerate(s[ikey]):\n",
    "            samples = audio[0,int(ts*sr):int(te*sr)]\n",
    "            # without padding the lpad/rpad fields reported below are simply zero\n",
    "            lpad = padding = 0\n",
    "            if pad_to_seconds is not None:\n",
    "                # int() keeps F.pad and randint working for a fractional pad_to_seconds\n",
    "                padding = int(pad_to_seconds*sr) - samples.shape[-1]\n",
    "                # a negative padding (chunk longer than the target) leaves no room for a shift\n",
    "                if random_shift and padding > 0:\n",
    "                    lpad = random.randint(0, padding)\n",
    "                samples = F.pad(samples, (lpad, padding-lpad))\n",
    "            subs = {\"__key__\": s['__key__'] + f\"_{i:03d}\",\n",
    "                    \"src_key\": s['__key__'],\n",
    "                    \"__url__\": s['__url__'],\n",
    "                    \"i\": i, \"imax\": imax,\n",
    "                    \"tstart\": ts, \"tend\": te, \"total_seconds\": audio.shape[-1]/sr,\n",
    "                    \"lpad\": lpad, \"rpad\": padding-lpad,\n",
    "                    \"lpad_s\": lpad/sr, \"rpad_s\": (padding-lpad)/sr,\n",
    "                    \"samples\": samples, \"sample_rate\": sr}\n",
    "            for k in metakeys:\n",
    "                subs[k] = s[k][i]\n",
    "            yield subs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ba34c6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def vad_dataset(shards, ikey='vad.npy', kind='vad'):\n",
    "    \"\"\"Open `shards` and produce one sample per segment listed under `ikey`.\n",
    "\n",
    "    The audio gets decoded, the derived `kind` dataset is merged in, samples without\n",
    "    an audio file are dropped and the remaining ones are split with `split_to_chunks`.\n",
    "    \"\"\"\n",
    "    def has_audio(sample):\n",
    "        # skip samples without audio\n",
    "        return any(ext in sample for ext in ('wav', 'flac', 'mp3', 'ogg'))\n",
    "\n",
    "    def to_chunks(stream):\n",
    "        return split_to_chunks(stream, ikey=ikey)\n",
    "\n",
    "    dataset = wds.WebDataset(shards)\n",
    "    return dataset.compose(\n",
    "        wds.decode(wds.torch_audio),\n",
    "        merge_in(derived_dataset(kind)),\n",
    "        wds.select(has_audio),\n",
    "        wds.rename(audio=\"flac;mp3;wav;ogg\"),\n",
    "        to_chunks,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e69c50e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@contextmanager\n",
    "def AtomicTarWriter(name, throwaway=False):\n",
    "    \"\"\"Context manager yielding a `wds.TarWriter` whose output shows up at `name` atomically.\n",
    "\n",
    "    Everything is written to `name + '.tmp'`. The file is moved to `name` only after the\n",
    "    `with` block finished without an exception. With `throwaway=True` the move is skipped\n",
    "    so `name` is never created (the `.tmp` file is left in place).\n",
    "\n",
    "    name: target file (a string or a path-like object); compression is requested from\n",
    "        `TarWriter` when it ends with `gz`\n",
    "    throwaway: do not move the finished file to `name`\n",
    "    \"\"\"\n",
    "    name = os.fspath(name)  # accept Path objects as well as strings\n",
    "    tmp = name+\".tmp\"\n",
    "    with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:\n",
    "        yield sink\n",
    "    if not throwaway:\n",
    "        # unlike os.rename this also replaces an existing file on Windows\n",
    "        os.replace(tmp, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68eaed68",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def readlines(fname):\n",
    "    \"\"\"Read the text file `fname` and return its lines with trailing whitespace removed.\"\"\"\n",
    "    stripped = []\n",
    "    with open(fname) as fh:\n",
    "        for raw_line in fh:\n",
    "            stripped.append(raw_line.rstrip())\n",
    "    return stripped"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
